function output = osc(Xcal, Ycal, nOSC, Xtest)
%OSC Orthogonal Signal Correction using Wold's iterative method.
%   output = osc(Xcal, Ycal, nOSC, Xtest)
%
%   Inputs
%     Xcal  - calibration (training) data, one sample per row
%     Ycal  - calibration concentration/response matrix, same number of
%             rows as Xcal
%     nOSC  - number of OSC components to remove (default 1; passing []
%             also selects the default)
%     Xtest - optional data to correct with the calibration model
%             (default [], meaning no test correction is done)
%
%   Output (struct)
%     W_orth, T_orth, P_orth - weights, scores and loadings of the removed
%                              components, one column per component
%     Xcal    - calibration data with the OSC components removed
%     R2Xcal  - fraction of the sum of squares of Xcal removed
%     Xtest, R2Xtest - corrected test data and the fraction of its sum of
%                      squares removed (only present when Xtest is non-empty)
%
%   Reference: Tom Fearn, On orthogonal signal correction,
%   Chemometrics and Intelligent Laboratory Systems 50 (2000) 47-52.

% Defaults must be resolved before any optional input is read.
if nargin < 3 || isempty(nOSC)
    nOSC = 1;
end
if nargin < 4
    Xtest = [];
end

tol = 1e-6;        % relative change in the score at which iteration stops
maxIter = 1000;    % safety cap so a non-converging score cannot hang

[nSamples, nVars] = size(Xcal);

% Ycal*pinv(Ycal) is the orthogonal projector onto the column space of
% Ycal (for any rank). Applying it as Ycal*(YcalPinv*v) never forms an
% n-by-n matrix, so memory stays O(n*m) instead of O(n^2).
YcalPinv = pinv(Ycal);

SSQcal = sum(Xcal(:).^2);
SSQtest = sum(Xtest(:).^2);

W = zeros(nVars, nOSC);
T = zeros(nSamples, nOSC);
P = zeros(nVars, nOSC);

for k = 1:nOSC
    % Initial guess: first principal component score of the current Xcal.
    [U, S, ~] = svds(Xcal, 1);
    tOld = U(:, 1) * S(1, 1);
    relChange = Inf;   % forces at least one pass through the loop
    iter = 0;

    while relChange > tol && iter < maxIter
        % Remove the part of the current score that lies in span(Ycal).
        tOrth = tOld - Ycal * (YcalPinv * tOld);

        % One-component PLS of the orthogonalised score on Xcal -> weights.
        w = plsnipals(Xcal, tOrth, 1);
        t = Xcal * w;   % updated score

        relChange = norm(t - tOld) / norm(tOld);
        tOld = t;
        iter = iter + 1;
    end
    if relChange > tol
        warning('osc:notConverged', ...
            'OSC component %d did not converge in %d iterations.', k, maxIter);
    end

    % Loadings of the removed component, then deflate Xcal.
    pk = (Xcal' * t) / (t' * t);
    Xcal = Xcal - t * pk';

    W(:, k) = w;
    T(:, k) = t;
    P(:, k) = pk;
end

output.W_orth = W;
output.T_orth = T;
output.P_orth = P;
output.Xcal = Xcal;
% Fraction of the calibration sum of squares removed by the OSC components.
output.R2Xcal = 1 - sum(Xcal(:).^2) / SSQcal;

% Correct new samples with the calibration weights and loadings, in order.
if ~isempty(Xtest)
    for k = 1:nOSC
        tTest = Xtest * W(:, k);
        Xtest = Xtest - tTest * P(:, k)';
    end
    output.Xtest = Xtest;
    output.R2Xtest = 1 - sum(Xtest(:).^2) / SSQtest;
end
end

%---------------------------------------------------
function [reg, i] = somesimpls(x, y, tol)
%SOMESIMPLS SIMPLS coefficients accumulated until the explained X sum of
%squares reaches TOL percent.
%   [reg, i] = somesimpls(x, y, tol)
%   x   - predictor matrix (cast to single internally)
%   y   - response; the weight bookkeeping assumes a single column
%         (TODO confirm before using with multi-column y)
%   tol - target cumulative percentage of the sum of squares of x
%   reg - row vector of regression coefficients (single precision)
%   i   - number of SIMPLS components used
%   Not called by osc; kept as a utility. Being a local (not nested)
%   function, its variables cannot collide with those of osc.

x = cast(x, 'single');
y = cast(y, 'single');

s = x' * y;                   % cross-product (covariance) of x with y
totvar = sum(x(:).^2);
maxComp = min(size(x));       % no more components than the rank bound
total = 0;
i = 0;

wts = zeros(size(x, 2), maxComp, 'single');      % x-block weights
qLoads = zeros(size(y, 2), maxComp, 'single');   % y-block loadings
basis = zeros(size(x, 2), maxComp, 'single');    % basis of x-loadings

while total < tol && i < maxComp
    i = i + 1;

    rr = s;                   % weights from the covariance
    tt = x * rr;
    normtt = norm(tt);
    tt = tt / normtt;
    rr = rr / normtt;
    pp = x' * tt;             % x-loadings
    qq = y' * tt;             % y-loadings

    % Orthogonalise the new loading against earlier ones, then deflate s.
    vv = pp;
    if i > 1
        prev = basis(:, 1:i-1);
        vv = vv - prev * (prev' * pp);
    end
    vv = vv / norm(vv);
    s = s - vv * (vv' * s);

    total = total + (pp' * pp) / totvar * 100;

    wts(:, i) = rr;
    qLoads(:, i) = qq;
    basis(:, i) = vv;
end

% Sum over components of weight * y-loading (same as the original branches
% for single-column y).
reg = (wts(:, 1:i) * qLoads(:, 1:i)')';
end

%---------------------------------------------------
function w = plsnipals(X, Y, A)
%PLSNIPALS X-weight vector from an A-component NIPALS PLS of Y on X.
%   w = plsnipals(X, Y, A) returns the (unit-norm) weight vector of the
%   last (A-th) component; X and Y are deflated between components.
%   Being a local (not nested) function, its variables cannot overwrite
%   variables of osc.
for comp = 1:A
    u = Y(:, 1);
    relChange = 1;
    nIter = 0;
    while relChange > 1e-8 && nIter < 1000   % convergence test with cap
        w = X' * u / (u' * u);
        w = w / norm(w);
        t = X * w;
        q = Y' * t / (t' * t);   % regress Y against t
        uNew = Y * q / (q' * q);
        relChange = norm(uNew - u) / norm(u);
        u = uNew;
        nIter = nIter + 1;
    end
    % Deflate X and Y before extracting the next component.
    p = X' * t / (t' * t);
    X = X - t * p';
    Y = Y - t * q';
end
end